import torch
import torch.nn as nn
import torch.nn.functional as F


class Encoder(nn.Module):
    """Encode nodes with the 'convolutional' GraphSage approach.

    A batch of nodes is embedded by aggregating neighbor features through
    ``aggregator``, optionally concatenating each node's own features, and
    applying a learned linear map followed by ReLU.

    Parameters
    ----------
    features :
        Callable invoked with a ``LongTensor`` of node ids; its result is
        concatenated with the aggregated neighbor features along ``dim=1``
        (so it is presumably ``(len(nodes), feature_dim)``).
    feature_dim :
        Width of one feature vector. The weight matrix has ``feature_dim``
        input columns when ``gcn`` is true and ``2 * feature_dim`` otherwise.
    embed_dim :
        Number of rows of the weight matrix, i.e. the output embedding size.
    adj_lists :
        Container indexed by ``int(node)`` that yields the neighbors of
        that node.
    aggregator :
        Object exposing ``forward(nodes, neighbor_lists, num_sample)``.
        Its ``cuda`` attribute is overwritten with the ``cuda`` flag.
    num_sample :
        Forwarded unchanged to ``aggregator.forward``.
    base_model :
        Optional object stored as ``self.base_model``; the attribute is not
        created at all when ``None`` is passed.
    gcn :
        When true, only the aggregated features are used (the node's own
        features are not concatenated).
    cuda :
        When true, the node index tensor is moved to the GPU before the
        feature lookup.
    feature_transform :
        Accepted for interface compatibility; not used by this class.
    """

    def __init__(
        self,
        features,
        feature_dim,
        embed_dim,
        adj_lists,
        aggregator,
        num_sample=10,
        base_model=None,
        gcn=False,
        cuda=False,
        feature_transform=False,
    ):
        super(Encoder, self).__init__()

        self.features = features
        self.feat_dim = feature_dim
        self.adj_lists = adj_lists
        self.aggregator = aggregator
        self.num_sample = num_sample
        # Compare against None explicitly: modules that define __len__
        # (e.g. an empty nn.Sequential) are falsy and would be dropped
        # by a plain truthiness test.
        if base_model is not None:
            self.base_model = base_model

        self.gcn = gcn
        self.embed_dim = embed_dim
        # NOTE: storing a bool under the name ``cuda`` shadows
        # nn.Module.cuda() on this instance. Kept for backward
        # compatibility with code that reads the flag.
        self.cuda = cuda
        self.aggregator.cuda = cuda

        # Without GCN the input is [self features | neighbor features],
        # hence twice the feature width.
        in_dim = self.feat_dim if self.gcn else 2 * self.feat_dim
        self.weight = nn.Parameter(torch.empty(embed_dim, in_dim, dtype=torch.float))
        nn.init.xavier_uniform_(self.weight)

    def forward(self, nodes):
        """Generate embeddings for a batch of nodes.

        Parameters
        ----------
        nodes :
            Sequence of node ids; each is converted with ``int`` to index
            ``adj_lists`` and the whole batch is converted to a long tensor
            for the feature lookup.

        Returns
        -------
        torch.Tensor
            ``relu(weight @ combined.T)``, which has ``embed_dim`` rows and
            one column per row of the combined feature matrix (presumably
            one per node).
        """
        neighbor_lists = [self.adj_lists[int(node)] for node in nodes]
        neigh_feats = self.aggregator.forward(nodes, neighbor_lists, self.num_sample)

        if self.gcn:
            combined = neigh_feats
        else:
            node_index = torch.as_tensor(nodes, dtype=torch.long)
            if self.cuda:
                node_index = node_index.cuda()
            self_feats = self.features(node_index)
            combined = torch.cat([self_feats, neigh_feats], dim=1)

        return F.relu(self.weight.mm(combined.t()))
